{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementation of BRNN\n",
    "\n",
    "This notebook runs the group's implementation of the Bayesian RNN (BRNN) on the Penn Treebank (PTB) dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Import the BRNN model implementation\n",
    "import brnn_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### How to run:\n",
    "\n",
    "The model can be run in three configurations: \"Small\", \"Medium\" and \"Large\" (a \"test\" configuration is also accepted).\n",
    "\n",
    "Without a GPU, the following approximate runtimes can be expected:\n",
    "* Approx. 4 hours for \"Small\"\n",
    "* Approx. 12 hours for \"Medium\"\n",
    "* Approx. 24 hours for \"Large\"\n",
    "\n",
    "By default, the notebook runs the \"small\" configuration with a Gaussian mixture prior, using a mixing proportion of 0.25 and log standard deviations of -1.0 and -7.0.\n",
    "\n",
    "To run with different settings, change the variable values in the next cell.\n",
    "\n",
    "To follow training in TensorBoard whilst the model runs, launch TensorBoard on the save path (e.g. `tensorboard --logdir tensorboard/`) and open localhost:6006."
   ]
  },
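  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The prior set in the next cell can be written out explicitly. As a sketch, assuming `brnn_model` uses the scale-mixture prior adopted by Fortunato et al., with `mix_pi` weighting the first component and the log values being natural logs of standard deviations (none of which is confirmed here), each weight $w$ has prior density\n",
    "\n",
    "$$p(w) = \\pi\\,\\mathcal{N}(w \\mid 0, \\sigma_1^2) + (1 - \\pi)\\,\\mathcal{N}(w \\mid 0, \\sigma_2^2), \\qquad \\sigma_1 = e^{\\ell_1}, \\quad \\sigma_2 = e^{\\ell_2},$$\n",
    "\n",
    "where $\\pi$ is `mix_pi`, $\\ell_1$ is `log_sigma1` and $\\ell_2$ is `log_sigma2`. With the defaults ($\\pi = 0.25$, $\\ell_1 = -1$, $\\ell_2 = -7$) this gives $\\sigma_1 \\approx 0.368$ and $\\sigma_2 \\approx 0.000912$, so three quarters of the prior mass lies on a very narrow Gaussian around zero and the remaining quarter on a much broader one."
   ]
  },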
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Model can be \"test\", \"small\", \"medium\", \"large\"\n",
    "model = \"small\"\n",
    "\n",
    "# Path to the PTB data directory\n",
    "data_path = \"../../data\"\n",
    "\n",
    "# Directory where checkpoints and TensorBoard summaries are saved\n",
    "save_path = \"tensorboard/\"\n",
    "\n",
    "# Mixing proportion of the Gaussian mixture prior\n",
    "# Fortunato et al. report scanning\n",
    "# mix_pi \\in { 1/4, 1/2, 3/4 }\n",
    "mix_pi = 0.25\n",
    "\n",
    "# Log standard deviations of the two prior mixture components;\n",
    "# Fortunato et al. report scanning\n",
    "# log sigma1 \\in { 0, -1, -2 }\n",
    "# log sigma2 \\in { -6, -7, -8 }\n",
    "log_sigma1 = -1.0\n",
    "log_sigma2 = -7.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Starting standard services.\n",
      "INFO:tensorflow:Saving checkpoint to path tensorboard/model.ckpt\n",
      "INFO:tensorflow:Starting queue runners.\n",
      "INFO:tensorflow:Model/global_step/sec: 0\n",
      "Epoch: 1 Learning rate: 1.000\n",
      "INFO:tensorflow:Recording summary at step 0.\n",
      "0.000 perplexity: 10414.302 speed: 591 wps\n",
      "KL is 126.22071838378906\n",
      "0.004 perplexity: 6523.971 speed: 3313 wps\n",
      "KL is 126.21833038330078\n",
      "0.104 perplexity: 923.538 speed: 7556 wps\n",
      "KL is 126.04281616210938\n",
      "0.204 perplexity: 682.637 speed: 7794 wps\n",
      "KL is 125.74568176269531\n",
      "0.304 perplexity: 558.201 speed: 7883 wps\n",
      "KL is 125.60706329345703\n",
      "0.404 perplexity: 484.591 speed: 7932 wps\n",
      "KL is 125.315673828125\n",
      "0.504 perplexity: 436.197 speed: 7960 wps\n",
      "KL is 125.13956451416016\n",
      "0.604 perplexity: 394.737 speed: 7976 wps\n",
      "KL is 124.90492248535156\n",
      "0.703 perplexity: 365.310 speed: 7990 wps\n",
      "KL is 124.59835052490234\n",
      "0.803 perplexity: 342.283 speed: 7998 wps\n",
      "KL is 124.41650390625\n",
      "0.903 perplexity: 321.555 speed: 8004 wps\n",
      "KL is 124.15804290771484\n",
      "Epoch: 1 Train Perplexity: 305.828\n",
      "Epoch: 1 Valid Perplexity: 201.375\n",
      "Epoch: 2 Learning rate: 1.000\n",
      "0.000 perplexity: 241.366 speed: 7408 wps\n",
      "KL is 123.90028381347656\n",
      "0.004 perplexity: 231.148 speed: 7661 wps\n",
      "KL is 123.89181518554688\n",
      "INFO:tensorflow:Model/global_step/sec: 19.7121\n",
      "INFO:tensorflow:Recording summary at step 2367.\n",
      "0.104 perplexity: 177.120 speed: 7980 wps\n",
      "KL is 123.63995361328125\n",
      "0.204 perplexity: 184.675 speed: 8011 wps\n",
      "KL is 123.31181335449219\n",
      "0.304 perplexity: 180.072 speed: 8021 wps\n",
      "KL is 123.1081314086914\n",
      "0.404 perplexity: 177.748 speed: 8040 wps\n",
      "KL is 122.81965637207031\n",
      "0.504 perplexity: 175.489 speed: 8043 wps\n",
      "KL is 122.58021545410156\n",
      "0.604 perplexity: 170.835 speed: 8045 wps\n",
      "KL is 122.30707550048828\n",
      "0.703 perplexity: 168.464 speed: 8052 wps\n",
      "KL is 122.07940673828125\n",
      "0.803 perplexity: 166.270 speed: 8055 wps\n",
      "KL is 121.75407409667969\n",
      "0.903 perplexity: 162.593 speed: 8058 wps\n",
      "KL is 121.50481414794922\n",
      "Epoch: 2 Train Perplexity: 160.320\n",
      "Epoch: 2 Valid Perplexity: 158.154\n",
      "Epoch: 3 Learning rate: 1.000\n",
      "0.000 perplexity: 210.173 speed: 7811 wps\n",
      "KL is 121.21764373779297\n",
      "0.004 perplexity: 168.688 speed: 7798 wps\n",
      "KL is 121.22041320800781\n",
      "INFO:tensorflow:Model/global_step/sec: 19.8833\n",
      "INFO:tensorflow:Recording summary at step 4753.\n",
      "0.104 perplexity: 130.445 speed: 8030 wps\n",
      "KL is 120.89619445800781\n",
      "0.204 perplexity: 140.185 speed: 8045 wps\n",
      "KL is 120.68277740478516\n",
      "0.304 perplexity: 137.528 speed: 8048 wps\n",
      "KL is 120.36227416992188\n",
      "0.404 perplexity: 137.073 speed: 8055 wps\n",
      "KL is 120.12006378173828\n",
      "0.504 perplexity: 136.322 speed: 8055 wps\n",
      "KL is 119.83710479736328\n",
      "0.604 perplexity: 133.597 speed: 8057 wps\n",
      "KL is 119.53964233398438\n",
      "0.703 perplexity: 132.943 speed: 8061 wps\n",
      "KL is 119.32041931152344\n",
      "0.803 perplexity: 132.071 speed: 8055 wps\n",
      "KL is 118.98394775390625\n",
      "0.903 perplexity: 129.718 speed: 8058 wps\n",
      "KL is 118.7263412475586\n",
      "Epoch: 3 Train Perplexity: 128.742\n",
      "Epoch: 3 Valid Perplexity: 139.949\n",
      "Epoch: 4 Learning rate: 1.000\n",
      "0.000 perplexity: 146.973 speed: 6708 wps\n",
      "KL is 118.46282958984375\n",
      "0.004 perplexity: 135.958 speed: 7866 wps\n",
      "KL is 118.46316528320312\n",
      "INFO:tensorflow:Model/global_step/sec: 19.8748\n",
      "INFO:tensorflow:Recording summary at step 7138.\n",
      "0.104 perplexity: 108.845 speed: 8056 wps\n",
      "KL is 118.19829559326172\n",
      "0.204 perplexity: 119.518 speed: 8067 wps\n",
      "KL is 117.84955596923828\n",
      "0.304 perplexity: 117.436 speed: 8079 wps\n",
      "KL is 117.58294677734375\n",
      "0.404 perplexity: 117.482 speed: 8076 wps\n",
      "KL is 117.29268646240234\n",
      "0.504 perplexity: 117.024 speed: 8076 wps\n",
      "KL is 116.99996185302734\n",
      "0.604 perplexity: 115.115 speed: 8076 wps\n",
      "KL is 116.71937561035156\n",
      "0.703 perplexity: 114.920 speed: 8080 wps\n",
      "KL is 116.47443389892578\n",
      "0.803 perplexity: 114.560 speed: 8082 wps\n",
      "KL is 116.1375732421875\n",
      "0.903 perplexity: 112.833 speed: 8084 wps\n",
      "KL is 115.87312316894531\n",
      "Epoch: 4 Train Perplexity: 112.288\n",
      "Epoch: 4 Valid Perplexity: 131.267\n",
      "Epoch: 5 Learning rate: 0.500\n",
      "0.000 perplexity: 169.776 speed: 6716 wps\n",
      "KL is 115.55931091308594\n",
      "0.004 perplexity: 118.862 speed: 7858 wps\n",
      "KL is 115.56016540527344\n",
      "INFO:tensorflow:Model/global_step/sec: 19.9252\n",
      "INFO:tensorflow:Recording summary at step 9528.\n",
      "0.104 perplexity: 93.078 speed: 8031 wps\n",
      "KL is 115.28292846679688\n",
      "0.204 perplexity: 101.175 speed: 8059 wps\n",
      "KL is 115.0841293334961\n",
      "0.304 perplexity: 98.101 speed: 8074 wps\n",
      "KL is 114.91189575195312\n",
      "0.404 perplexity: 97.555 speed: 8085 wps\n",
      "KL is 114.690673828125\n",
      "0.504 perplexity: 96.486 speed: 8082 wps\n",
      "KL is 114.55152130126953\n",
      "0.604 perplexity: 94.372 speed: 8087 wps\n",
      "KL is 114.32633972167969\n",
      "0.703 perplexity: 93.715 speed: 8090 wps\n",
      "KL is 114.11387634277344\n",
      "0.803 perplexity: 92.948 speed: 8089 wps\n",
      "KL is 114.03285217285156\n",
      "0.903 perplexity: 91.099 speed: 8088 wps\n",
      "KL is 113.87857818603516\n",
      "Epoch: 5 Train Perplexity: 90.269\n",
      "Epoch: 5 Valid Perplexity: 113.028\n",
      "Epoch: 6 Learning rate: 0.250\n",
      "0.000 perplexity: 104.128 speed: 6948 wps\n",
      "KL is 113.6947021484375\n",
      "0.004 perplexity: 100.902 speed: 7933 wps\n",
      "KL is 113.7021713256836\n",
      "0.104 perplexity: 80.053 speed: 8057 wps\n",
      "KL is 113.56903076171875\n",
      "INFO:tensorflow:Saving checkpoint to path tensorboard/model.ckpt\n",
      "INFO:tensorflow:Model/global_step/sec: 19.9334\n",
      "INFO:tensorflow:Recording summary at step 11920.\n",
      "0.204 perplexity: 87.818 speed: 7989 wps\n",
      "KL is 113.50247955322266\n",
      "0.304 perplexity: 85.173 speed: 8032 wps\n",
      "KL is 113.38402557373047\n",
      "0.404 perplexity: 84.816 speed: 8049 wps\n",
      "KL is 113.27145385742188\n",
      "0.504 perplexity: 83.921 speed: 8057 wps\n",
      "KL is 113.23677825927734\n",
      "0.604 perplexity: 82.161 speed: 8064 wps\n",
      "KL is 113.1521224975586\n",
      "0.703 perplexity: 81.623 speed: 8066 wps\n",
      "KL is 113.05086517333984\n",
      "0.803 perplexity: 80.920 speed: 8068 wps\n",
      "KL is 112.93460083007812\n",
      "0.903 perplexity: 79.218 speed: 8064 wps\n",
      "KL is 112.83402252197266\n",
      "Epoch: 6 Train Perplexity: 78.378\n",
      "Epoch: 6 Valid Perplexity: 105.654\n",
      "Epoch: 7 Learning rate: 0.125\n",
      "0.000 perplexity: 130.552 speed: 7117 wps\n",
      "KL is 112.86695861816406\n",
      "0.004 perplexity: 93.945 speed: 7851 wps\n",
      "KL is 112.81212615966797\n",
      "0.104 perplexity: 73.480 speed: 8076 wps\n",
      "KL is 112.76998901367188\n",
      "INFO:tensorflow:Model/global_step/sec: 19.8917\n",
      "INFO:tensorflow:Recording summary at step 14308.\n",
      "0.204 perplexity: 81.164 speed: 8090 wps\n",
      "KL is 112.76165008544922\n",
      "0.304 perplexity: 78.777 speed: 8075 wps\n",
      "KL is 112.68804931640625\n",
      "0.404 perplexity: 78.610 speed: 8079 wps\n",
      "KL is 112.59851837158203\n",
      "0.504 perplexity: 77.787 speed: 8082 wps\n",
      "KL is 112.62954711914062\n",
      "0.604 perplexity: 76.182 speed: 8082 wps\n",
      "KL is 112.57152557373047\n",
      "0.703 perplexity: 75.676 speed: 8080 wps\n",
      "KL is 112.51260375976562\n",
      "0.803 perplexity: 74.954 speed: 8080 wps\n",
      "KL is 112.43880462646484\n",
      "0.903 perplexity: 73.340 speed: 8083 wps\n",
      "KL is 112.42378997802734\n",
      "Epoch: 7 Train Perplexity: 72.561\n",
      "Epoch: 7 Valid Perplexity: 102.448\n",
      "Epoch: 8 Learning rate: 0.062\n",
      "0.000 perplexity: 116.927 speed: 6993 wps\n",
      "KL is 112.32495880126953\n",
      "0.004 perplexity: 90.073 speed: 8046 wps\n",
      "KL is 112.3741455078125\n",
      "0.104 perplexity: 70.350 speed: 8044 wps\n",
      "KL is 112.34636688232422\n",
      "INFO:tensorflow:Model/global_step/sec: 19.9167\n",
      "INFO:tensorflow:Recording summary at step 16698.\n",
      "0.204 perplexity: 77.928 speed: 8060 wps\n",
      "KL is 112.3124771118164\n",
      "0.304 perplexity: 75.649 speed: 8064 wps\n",
      "KL is 112.31472778320312\n",
      "0.404 perplexity: 75.444 speed: 8070 wps\n",
      "KL is 112.32982635498047\n",
      "0.504 perplexity: 74.653 speed: 8071 wps\n",
      "KL is 112.30723571777344\n",
      "0.604 perplexity: 73.145 speed: 8074 wps\n",
      "KL is 112.26702117919922\n",
      "0.703 perplexity: 72.638 speed: 8077 wps\n",
      "KL is 112.23633575439453\n",
      "0.803 perplexity: 71.903 speed: 8081 wps\n",
      "KL is 112.18150329589844\n",
      "0.903 perplexity: 70.360 speed: 8080 wps\n",
      "KL is 112.21633911132812\n",
      "Epoch: 8 Train Perplexity: 69.608\n",
      "Epoch: 8 Valid Perplexity: 100.700\n",
      "Epoch: 9 Learning rate: 0.031\n",
      "0.000 perplexity: 103.774 speed: 7771 wps\n",
      "KL is 112.17375946044922\n",
      "0.004 perplexity: 89.152 speed: 8240 wps\n",
      "KL is 112.13628387451172\n",
      "0.104 perplexity: 68.792 speed: 8057 wps\n",
      "KL is 112.18099212646484\n",
      "0.204 perplexity: 76.094 speed: 8077 wps\n",
      "KL is 112.13301086425781\n",
      "INFO:tensorflow:Model/global_step/sec: 19.9333\n",
      "INFO:tensorflow:Recording summary at step 19089.\n",
      "0.304 perplexity: 73.930 speed: 8070 wps\n",
      "KL is 112.16000366210938\n",
      "0.404 perplexity: 73.770 speed: 8075 wps\n",
      "KL is 112.07518768310547\n",
      "0.504 perplexity: 72.976 speed: 8079 wps\n",
      "KL is 112.1141586303711\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.604 perplexity: 71.487 speed: 8079 wps\n",
      "KL is 112.13560485839844\n",
      "0.703 perplexity: 70.995 speed: 8077 wps\n",
      "KL is 112.07653045654297\n",
      "0.803 perplexity: 70.280 speed: 8078 wps\n",
      "KL is 112.09648895263672\n",
      "0.903 perplexity: 68.746 speed: 8080 wps\n",
      "KL is 112.0745620727539\n",
      "Epoch: 9 Train Perplexity: 68.009\n",
      "Epoch: 9 Valid Perplexity: 99.718\n",
      "Epoch: 10 Learning rate: 0.016\n",
      "0.000 perplexity: 90.702 speed: 7682 wps\n",
      "KL is 112.0698471069336\n",
      "0.004 perplexity: 85.808 speed: 7856 wps\n",
      "KL is 112.10103607177734\n",
      "0.104 perplexity: 67.857 speed: 8062 wps\n",
      "KL is 112.10498809814453\n",
      "0.204 perplexity: 75.096 speed: 8064 wps\n",
      "KL is 112.0870361328125\n",
      "INFO:tensorflow:Model/global_step/sec: 19.9084\n",
      "INFO:tensorflow:Recording summary at step 21479.\n",
      "0.304 perplexity: 72.865 speed: 8065 wps\n",
      "KL is 112.0393295288086\n",
      "0.404 perplexity: 72.807 speed: 8074 wps\n",
      "KL is 112.09526062011719\n",
      "0.504 perplexity: 72.034 speed: 8074 wps\n",
      "KL is 112.0924301147461\n",
      "0.604 perplexity: 70.612 speed: 8078 wps\n",
      "KL is 112.05937957763672\n",
      "0.703 perplexity: 70.063 speed: 8079 wps\n",
      "KL is 112.07566833496094\n",
      "0.803 perplexity: 69.380 speed: 8081 wps\n",
      "KL is 112.08545684814453\n",
      "0.903 perplexity: 67.849 speed: 8080 wps\n",
      "KL is 112.02644348144531\n",
      "Epoch: 10 Train Perplexity: 67.162\n",
      "Epoch: 10 Valid Perplexity: 99.213\n",
      "Epoch: 11 Learning rate: 0.008\n",
      "0.000 perplexity: 123.409 speed: 6985 wps\n",
      "KL is 112.02046203613281\n",
      "0.004 perplexity: 85.996 speed: 7769 wps\n",
      "KL is 112.0817642211914\n",
      "0.104 perplexity: 67.216 speed: 8093 wps\n",
      "KL is 111.9701919555664\n",
      "0.204 perplexity: 74.494 speed: 8092 wps\n",
      "KL is 112.02540588378906\n",
      "INFO:tensorflow:Saving checkpoint to path tensorboard/model.ckpt\n",
      "INFO:tensorflow:Recording summary at step 23870.\n",
      "0.304 perplexity: 72.291 speed: 8030 wps\n",
      "KL is 112.01083374023438\n",
      "0.404 perplexity: 72.300 speed: 8041 wps\n",
      "KL is 112.05270385742188\n",
      "0.504 perplexity: 71.485 speed: 8049 wps\n",
      "KL is 112.02420043945312\n",
      "0.604 perplexity: 70.157 speed: 8053 wps\n",
      "KL is 111.99243927001953\n",
      "0.703 perplexity: 69.609 speed: 8060 wps\n",
      "KL is 111.97142028808594\n",
      "0.803 perplexity: 68.920 speed: 8064 wps\n",
      "KL is 111.9666976928711\n",
      "0.903 perplexity: 67.376 speed: 8066 wps\n",
      "KL is 112.01941680908203\n",
      "Epoch: 11 Train Perplexity: 66.715\n",
      "Epoch: 11 Valid Perplexity: 99.086\n",
      "Epoch: 12 Learning rate: 0.004\n",
      "0.000 perplexity: 103.460 speed: 7866 wps\n",
      "KL is 112.01818084716797\n",
      "0.004 perplexity: 84.028 speed: 8076 wps\n",
      "KL is 112.03929901123047\n",
      "0.104 perplexity: 66.956 speed: 8077 wps\n",
      "KL is 112.01687622070312\n",
      "0.204 perplexity: 74.180 speed: 8069 wps\n",
      "KL is 111.97514343261719\n",
      "INFO:tensorflow:Recording summary at step 26256.\n",
      "0.304 perplexity: 72.048 speed: 8055 wps\n",
      "KL is 111.9974136352539\n",
      "0.404 perplexity: 72.124 speed: 8052 wps\n",
      "KL is 112.05007934570312\n",
      "0.504 perplexity: 71.336 speed: 8058 wps\n",
      "KL is 112.00308990478516\n",
      "0.604 perplexity: 69.979 speed: 8067 wps\n",
      "KL is 112.06607055664062\n",
      "0.703 perplexity: 69.454 speed: 8070 wps\n",
      "KL is 112.08475494384766\n",
      "0.803 perplexity: 68.741 speed: 8073 wps\n",
      "KL is 112.08727264404297\n",
      "0.903 perplexity: 67.223 speed: 8073 wps\n",
      "KL is 112.06195068359375\n",
      "Epoch: 12 Train Perplexity: 66.552\n",
      "Epoch: 12 Valid Perplexity: 98.932\n",
      "Epoch: 13 Learning rate: 0.002\n",
      "0.000 perplexity: 118.861 speed: 6889 wps\n",
      "KL is 112.01956176757812\n",
      "0.004 perplexity: 81.165 speed: 7783 wps\n",
      "KL is 111.98167419433594\n",
      "0.104 perplexity: 67.051 speed: 8069 wps\n",
      "KL is 111.96552276611328\n",
      "0.204 perplexity: 74.182 speed: 8062 wps\n",
      "KL is 112.01422882080078\n",
      "0.304 perplexity: 71.918 speed: 8069 wps\n",
      "KL is 112.00042724609375\n",
      "INFO:tensorflow:Recording summary at step 28645.\n",
      "0.404 perplexity: 72.017 speed: 8071 wps\n",
      "KL is 111.91899871826172\n",
      "0.504 perplexity: 71.227 speed: 8071 wps\n",
      "KL is 111.98045349121094\n",
      "0.604 perplexity: 69.858 speed: 8075 wps\n",
      "KL is 112.01286315917969\n",
      "0.703 perplexity: 69.331 speed: 8073 wps\n",
      "KL is 111.98625183105469\n",
      "0.803 perplexity: 68.619 speed: 8072 wps\n",
      "KL is 111.98152923583984\n",
      "0.903 perplexity: 67.121 speed: 8071 wps\n",
      "KL is 112.01779174804688\n",
      "Epoch: 13 Train Perplexity: 66.473\n",
      "Epoch: 13 Valid Perplexity: 98.854\n",
      "INFO:tensorflow:Recording summary at step 30199.\n",
      "Test Perplexity: 95.597\n",
      "Saving model to tensorboard/.\n"
     ]
    }
   ],
   "source": [
    "# Run the model. No changes are needed here.\n",
    "brnn_model.main(model_select=model,\n",
    "                dat_path=data_path,\n",
    "                sav_path=save_path,\n",
    "                mixing_pi=mix_pi,\n",
    "                prior_log_sigma1=log_sigma1,\n",
    "                prior_log_sigma2=log_sigma2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
